"""
    Directional-weighted, gradient-based demosaicing of Bayer raw images

    author: wxz
    date: 
    github: https://github.com/xinzwang
"""

import numpy as np
from isp.blocks.utils import *


def demosaik(img, bayer_pattern="rggb"):
    """Demosaic a single-channel Bayer mosaic into an RGB image.

    Green is interpolated first with ``fill_channel_directional_weight``;
    red and blue are then completed with ``fill_br_locations`` using that
    green plane.

    Args:
        img: 2-D Bayer mosaic of shape (height, width). The result is
            clipped to [0, 1], so the input is presumably normalised to
            that range -- confirm with the caller.
        bayer_pattern: colour-filter layout, one of "rggb" (default),
            "bggr", "gbrg", "grbg".

    Returns:
        float32 array of shape (height, width, 3) in R, G, B channel order,
        clipped to [0, 1].

    Raises:
        ValueError: if ``bayer_pattern`` is not one of the supported layouts.
    """
    if bayer_pattern not in ("rggb", "bggr", "gbrg", "grbg"):
        raise ValueError("unsupported bayer_pattern: {!r}".format(bayer_pattern))

    img = np.asarray(img)
    G = fill_channel_directional_weight(img, bayer_pattern)
    B, R = fill_br_locations(img, G, bayer_pattern)

    # unpacking enforces a 2-D mosaic
    height, width = img.shape
    out = np.empty((height, width, 3), dtype=np.float32)

    out[:, :, 0] = R
    out[:, :, 1] = G
    out[:, :, 2] = B

    return np.clip(out, 0, 1)


def fill_channel_directional_weight(data, bayer_pattern):
    """Interpolate the quincunx-sampled ("green") channel of a Bayer mosaic.

    The samples kept as-is are those at ``[::2, 1::2]`` and ``[1::2, ::2]``
    for ``"rggb"``/``"bggr"``, and at ``[::2, ::2]`` and ``[1::2, 1::2]`` for
    ``"gbrg"``/``"grbg"``. Every other pixel receives a weighted mean of four
    directional estimates (north, east, south, west). Each estimate is the
    adjacent known sample corrected by half the central difference taken at
    that neighbour. Its weight is ``1 / (1 + g)``, where ``g`` is the sum of
    two absolute central differences along that direction, so directions
    crossing a strong gradient contribute less.

    Rows/columns that fall outside the image are replaced by the row/column
    mirrored about the border one (index ``1`` for top/left, ``-2`` for
    bottom/right). This keeps the Bayer colour of the missing row/column.

    Args:
        data: 2-D mosaic. The strided slicing assumes an even height and
            width.
        bayer_pattern: one of ``"rggb"``, ``"bggr"``, ``"gbrg"``, ``"grbg"``.

    Returns:
        2-D float array with the same shape as ``data``.

    Raises:
        ValueError: if ``bayer_pattern`` is not one of the supported layouts.
    """
    if bayer_pattern not in ("rggb", "bggr", "gbrg", "grbg"):
        raise ValueError("unsupported bayer_pattern: {!r}".format(bayer_pattern))

    data = np.asarray(data)
    # v[i, j] = data[i + 1, j] - data[i - 1, j]
    # h[i, j] = data[i, j + 1] - data[i, j - 1]
    v = np.asarray(signal.convolve2d(data, [[1], [0], [-1]], mode="same", boundary="symm"))
    h = np.asarray(signal.convolve2d(data, [[1, 0, -1]], mode="same", boundary="symm"))

    weight_N = np.zeros(np.shape(data), dtype=np.float32)
    weight_E = np.zeros(np.shape(data), dtype=np.float32)
    weight_S = np.zeros(np.shape(data), dtype=np.float32)
    weight_W = np.zeros(np.shape(data), dtype=np.float32)

    value_N = np.zeros(np.shape(data), dtype=np.float32)
    value_E = np.zeros(np.shape(data), dtype=np.float32)
    value_S = np.zeros(np.shape(data), dtype=np.float32)
    value_W = np.zeros(np.shape(data), dtype=np.float32)

    if bayer_pattern in ("rggb", "bggr"):

        # Locations in the comments below assume rggb:
        # R at [::2, ::2], B at [1::2, 1::2].

        # == CALCULATE WEIGHTS IN B LOCATIONS
        weight_N[1::2, 1::2] = np.abs(v[1::2, 1::2]) + np.abs(v[::2, 1::2])

        # append the column before the last on the right so that the [2::2]
        # sampling does not cause any dimension mismatch
        temp_h_b = np.hstack((h, np.atleast_2d(h[:, -2]).T))
        weight_E[1::2, 1::2] = np.abs(h[1::2, 1::2]) + np.abs(temp_h_b[1::2, 2::2])

        # append the row before the last at the bottom (index -2, not -1, so
        # the padded row has the same Bayer colour as the missing one)
        temp_v_b = np.vstack((v, v[-2]))
        weight_S[1::2, 1::2] = np.abs(v[1::2, 1::2]) + np.abs(temp_v_b[2::2, 1::2])
        weight_W[1::2, 1::2] = np.abs(h[1::2, 1::2]) + np.abs(h[1::2, ::2])

        # == CALCULATE WEIGHTS IN R LOCATIONS
        # prepend the second row at the top and drop the bottom row
        temp_v_r = np.delete(np.vstack((v[1], v)), -1, 0)
        weight_N[::2, ::2] = np.abs(v[::2, ::2]) + np.abs(temp_v_r[::2, ::2])

        weight_E[::2, ::2] = np.abs(h[::2, ::2]) + np.abs(h[::2, 1::2])

        weight_S[::2, ::2] = np.abs(v[::2, ::2]) + np.abs(v[1::2, ::2])

        # prepend the second column on the left and drop the rightmost column
        temp_h_r = np.delete(np.hstack((np.atleast_2d(h[:, 1]).T, h)), -1, 1)
        weight_W[::2, ::2] = np.abs(h[::2, ::2]) + np.abs(temp_h_r[::2, ::2])

        weight_N = np.divide(1., 1. + weight_N)
        weight_E = np.divide(1., 1. + weight_E)
        weight_S = np.divide(1., 1. + weight_S)
        weight_W = np.divide(1., 1. + weight_W)

        # == CALCULATE DIRECTIONAL ESTIMATES IN B LOCATIONS
        value_N[1::2, 1::2] = data[::2, 1::2] + v[::2, 1::2] / 2.

        # append the column before the last on the right
        temp = np.hstack((data, np.atleast_2d(data[:, -2]).T))
        value_E[1::2, 1::2] = temp[1::2, 2::2] - temp_h_b[1::2, 2::2] / 2.

        # append the row before the last at the bottom (parity-preserving)
        temp = np.vstack((data, data[-2]))
        value_S[1::2, 1::2] = temp[2::2, 1::2] - temp_v_b[2::2, 1::2] / 2.

        value_W[1::2, 1::2] = data[1::2, ::2] + h[1::2, ::2] / 2.

        # == CALCULATE DIRECTIONAL ESTIMATES IN R LOCATIONS
        # prepend the second row at the top and drop the bottom row
        temp = np.delete(np.vstack((data[1], data)), -1, 0)
        value_N[::2, ::2] = temp[::2, ::2] + temp_v_r[::2, ::2] / 2.

        value_E[::2, ::2] = data[::2, 1::2] - h[::2, 1::2] / 2.

        value_S[::2, ::2] = data[1::2, ::2] - v[1::2, ::2] / 2.

        # prepend the second column on the left and drop the rightmost column
        temp = np.delete(np.hstack((np.atleast_2d(data[:, 1]).T, data)), -1, 1)
        value_W[::2, ::2] = temp[::2, ::2] + temp_h_r[::2, ::2] / 2.

        output = (value_N * weight_N + value_E * weight_E
                  + value_S * weight_S + value_W * weight_W) \
            / (weight_N + weight_E + weight_S + weight_W)

        # keep the known green samples untouched
        output[::2, 1::2] = data[::2, 1::2]
        output[1::2, ::2] = data[1::2, ::2]

        return output

    # bayer_pattern is "gbrg" or "grbg".
    # Locations in the comments below assume gbrg:
    # B at [::2, 1::2], R at [1::2, ::2].

    # == CALCULATE WEIGHTS IN B LOCATIONS
    # prepend the second row at the top and drop the bottom row
    temp_v_b = np.delete(np.vstack((v[1], v)), -1, 0)
    weight_N[::2, 1::2] = np.abs(v[::2, 1::2]) + np.abs(temp_v_b[::2, 1::2])

    # append the column before the last on the right
    temp_h_b = np.hstack((h, np.atleast_2d(h[:, -2]).T))
    weight_E[::2, 1::2] = np.abs(h[::2, 1::2]) + np.abs(temp_h_b[::2, 2::2])

    weight_S[::2, 1::2] = np.abs(v[::2, 1::2]) + np.abs(v[1::2, 1::2])
    weight_W[::2, 1::2] = np.abs(h[::2, 1::2]) + np.abs(h[::2, ::2])

    # == CALCULATE WEIGHTS IN R LOCATIONS
    weight_N[1::2, ::2] = np.abs(v[1::2, ::2]) + np.abs(v[::2, ::2])
    weight_E[1::2, ::2] = np.abs(h[1::2, ::2]) + np.abs(h[1::2, 1::2])

    # append the row before the last at the bottom (parity-preserving)
    temp_v_r = np.vstack((v, v[-2]))
    weight_S[1::2, ::2] = np.abs(v[1::2, ::2]) + np.abs(temp_v_r[2::2, ::2])

    # prepend the second column on the left and drop the rightmost column
    temp_h_r = np.delete(np.hstack((np.atleast_2d(h[:, 1]).T, h)), -1, 1)
    weight_W[1::2, ::2] = np.abs(h[1::2, ::2]) + np.abs(temp_h_r[1::2, ::2])

    weight_N = np.divide(1., 1. + weight_N)
    weight_E = np.divide(1., 1. + weight_E)
    weight_S = np.divide(1., 1. + weight_S)
    weight_W = np.divide(1., 1. + weight_W)

    # == CALCULATE DIRECTIONAL ESTIMATES IN B LOCATIONS
    # prepend the second row at the top and drop the bottom row
    temp = np.delete(np.vstack((data[1], data)), -1, 0)
    value_N[::2, 1::2] = temp[::2, 1::2] + temp_v_b[::2, 1::2] / 2.

    # append the column before the last on the right
    temp = np.hstack((data, np.atleast_2d(data[:, -2]).T))
    value_E[::2, 1::2] = temp[::2, 2::2] - temp_h_b[::2, 2::2] / 2.

    value_S[::2, 1::2] = data[1::2, 1::2] - v[1::2, 1::2] / 2.

    value_W[::2, 1::2] = data[::2, ::2] + h[::2, ::2] / 2.

    # == CALCULATE DIRECTIONAL ESTIMATES IN R LOCATIONS
    value_N[1::2, ::2] = data[::2, ::2] + v[::2, ::2] / 2.
    value_E[1::2, ::2] = data[1::2, 1::2] - h[1::2, 1::2] / 2.

    # append the row before the last at the bottom (parity-preserving)
    temp = np.vstack((data, data[-2]))
    value_S[1::2, ::2] = temp[2::2, ::2] - temp_v_r[2::2, ::2] / 2.

    # prepend the second column on the left and drop the rightmost column
    temp = np.delete(np.hstack((np.atleast_2d(data[:, 1]).T, data)), -1, 1)
    value_W[1::2, ::2] = temp[1::2, ::2] + temp_h_r[1::2, ::2] / 2.

    output = (value_N * weight_N + value_E * weight_E
              + value_S * weight_S + value_W * weight_W) \
        / (weight_N + weight_E + weight_S + weight_W)

    # keep the known green samples untouched
    output[::2, ::2] = data[::2, ::2]
    output[1::2, 1::2] = data[1::2, 1::2]

    return output


def fill_br_locations(data, G, bayer_pattern):
    """Complete the R and B planes of a Bayer mosaic from a full G plane.

    Stage 1 estimates B at every R site and R at every B site from the four
    diagonal neighbours. The estimate from diagonal neighbour X is
    ``data[X] + (G[centre] - G[X]) / 2``. Its weight is
    ``1 / (1 + |raw-data difference along that diagonal| + |G[centre] - G[X]|)``.

    Stage 2 copies ``G`` into the green sites of the partially filled R and
    B planes. It then calls ``fill_channel_directional_weight`` on each plane
    with a layout whose preserved samples are the (now fully known) R/B
    sites, so only the green sites get interpolated.

    Rows/columns outside the image are replaced by the row/column mirrored
    about the border one (index 1 for top/left, -2 for bottom/right). This
    keeps the Bayer colour of the missing samples.

    Args:
        data: 2-D Bayer mosaic. The strided slicing assumes an even height
            and width.
        G: full green plane with the same shape as ``data`` (e.g. the
            output of ``fill_channel_directional_weight``).
        bayer_pattern: one of "rggb", "bggr", "gbrg", "grbg".

    Returns:
        Tuple ``(B, R)`` of 2-D float arrays shaped like ``data``.

    Raises:
        ValueError: if ``bayer_pattern`` is not one of the supported layouts.
    """
    if bayer_pattern not in ("rggb", "bggr", "gbrg", "grbg"):
        raise ValueError("unsupported bayer_pattern: {!r}".format(bayer_pattern))

    data = np.asarray(data)
    G = np.asarray(G)
    B = np.zeros(np.shape(data), dtype=np.float32)
    R = np.zeros(np.shape(data), dtype=np.float32)

    # d1 = data[NW] - data[SE]; d2 = data[SW] - data[NE]
    d1 = np.asarray(signal.convolve2d(data, [[-1, 0, 0], [0, 0, 0], [0, 0, 1]], mode="same", boundary="symm"))
    d2 = np.asarray(signal.convolve2d(data, [[0, 0, 1], [0, 0, 0], [-1, 0, 0]], mode="same", boundary="symm"))

    # df_X = G[centre] - G[X] for each diagonal neighbour X
    df_NE = np.asarray(signal.convolve2d(G, [[0, 0, 0], [0, 1, 0], [-1, 0, 0]], mode="same", boundary="symm"))
    df_SE = np.asarray(signal.convolve2d(G, [[-1, 0, 0], [0, 1, 0], [0, 0, 0]], mode="same", boundary="symm"))
    df_SW = np.asarray(signal.convolve2d(G, [[0, 0, -1], [0, 1, 0], [0, 0, 0]], mode="same", boundary="symm"))
    df_NW = np.asarray(signal.convolve2d(G, [[0, 0, 0], [0, 1, 0], [0, 0, -1]], mode="same", boundary="symm"))

    weight_NE = np.zeros(np.shape(data), dtype=np.float32)
    weight_SE = np.zeros(np.shape(data), dtype=np.float32)
    weight_SW = np.zeros(np.shape(data), dtype=np.float32)
    weight_NW = np.zeros(np.shape(data), dtype=np.float32)

    value_NE = np.zeros(np.shape(data), dtype=np.float32)
    value_SE = np.zeros(np.shape(data), dtype=np.float32)
    value_SW = np.zeros(np.shape(data), dtype=np.float32)
    value_NW = np.zeros(np.shape(data), dtype=np.float32)

    if bayer_pattern in ("rggb", "bggr"):
        # Sites are named assuming rggb: R at [::2, ::2], B at [1::2, 1::2].

        # == weights for B in R locations
        weight_NE[::2, ::2] = np.abs(d2[::2, ::2]) + np.abs(df_NE[::2, ::2])
        weight_SE[::2, ::2] = np.abs(d1[::2, ::2]) + np.abs(df_SE[::2, ::2])
        weight_SW[::2, ::2] = np.abs(d2[::2, ::2]) + np.abs(df_SW[::2, ::2])
        weight_NW[::2, ::2] = np.abs(d1[::2, ::2]) + np.abs(df_NW[::2, ::2])

        # == weights for R in B locations
        weight_NE[1::2, 1::2] = np.abs(d2[1::2, 1::2]) + np.abs(df_NE[1::2, 1::2])
        weight_SE[1::2, 1::2] = np.abs(d1[1::2, 1::2]) + np.abs(df_SE[1::2, 1::2])
        weight_SW[1::2, 1::2] = np.abs(d2[1::2, 1::2]) + np.abs(df_SW[1::2, 1::2])
        weight_NW[1::2, 1::2] = np.abs(d1[1::2, 1::2]) + np.abs(df_NW[1::2, 1::2])

        weight_NE = np.divide(1., 1. + weight_NE)
        weight_SE = np.divide(1., 1. + weight_SE)
        weight_SW = np.divide(1., 1. + weight_SW)
        weight_NW = np.divide(1., 1. + weight_NW)

        # == directional estimates of B in R locations
        # shifted_down[i] = data[i - 1] (row 1 mirrored in at the top)
        shifted_down = np.delete(np.vstack((data[1], data)), -1, 0)
        value_NE[::2, ::2] = shifted_down[::2, 1::2] + df_NE[::2, ::2] / 2.
        value_SE[::2, ::2] = data[1::2, 1::2] + df_SE[::2, ::2] / 2.
        # shifted_right[:, j] = data[:, j - 1] (column 1 mirrored in at the left)
        shifted_right = np.delete(np.hstack((np.atleast_2d(data[:, 1]).T, data)), -1, 1)
        value_SW[::2, ::2] = shifted_right[1::2, ::2] + df_SW[::2, ::2] / 2.
        # shifted_down_right[i, j] = data[i - 1, j - 1]
        shifted_down_right = np.delete(
            np.hstack((np.atleast_2d(shifted_down[:, 1]).T, shifted_down)), -1, 1)
        # halved like the other three diagonal estimates
        value_NW[::2, ::2] = shifted_down_right[::2, ::2] + df_NW[::2, ::2] / 2.

        # == directional estimates of R in B locations
        # append the column before the last on the right
        padded_right = np.hstack((data, np.atleast_2d(data[:, -2]).T))
        value_NE[1::2, 1::2] = padded_right[::2, 2::2] + df_NE[1::2, 1::2] / 2.
        # append the row before the last at the bottom (parity-preserving)
        padded_right_bottom = np.vstack((padded_right, padded_right[-2]))
        value_SE[1::2, 1::2] = padded_right_bottom[2::2, 2::2] + df_SE[1::2, 1::2] / 2.
        padded_bottom = np.vstack((data, data[-2]))
        value_SW[1::2, 1::2] = padded_bottom[2::2, ::2] + df_SW[1::2, 1::2] / 2.
        value_NW[1::2, 1::2] = data[::2, ::2] + df_NW[1::2, 1::2] / 2.

        RB = (weight_NE * value_NE + weight_SE * value_SE
              + weight_SW * value_SW + weight_NW * value_NW) \
            / (weight_NE + weight_SE + weight_SW + weight_NW)

        if bayer_pattern == "rggb":
            R[1::2, 1::2] = RB[1::2, 1::2]
            R[::2, ::2] = data[::2, ::2]
            B[::2, ::2] = RB[::2, ::2]
            B[1::2, 1::2] = data[1::2, 1::2]
        else:  # "bggr"
            R[::2, ::2] = RB[::2, ::2]
            R[1::2, 1::2] = data[1::2, 1::2]
            B[1::2, 1::2] = RB[1::2, 1::2]
            B[::2, ::2] = data[::2, ::2]

        # Fill the green sites; the "gbrg" layout keeps [::2, ::2] and
        # [1::2, 1::2] (the R/B sites, now known) as fixed samples.
        R[1::2, ::2] = G[1::2, ::2]
        R[::2, 1::2] = G[::2, 1::2]
        R = fill_channel_directional_weight(R, "gbrg")

        B[1::2, ::2] = G[1::2, ::2]
        B[::2, 1::2] = G[::2, 1::2]
        B = fill_channel_directional_weight(B, "gbrg")

    else:  # "grbg" or "gbrg"
        # Sites are named assuming gbrg: B at [::2, 1::2], R at [1::2, ::2].

        # == weights at [::2, 1::2] sites
        weight_NE[::2, 1::2] = np.abs(d2[::2, 1::2]) + np.abs(df_NE[::2, 1::2])
        weight_SE[::2, 1::2] = np.abs(d1[::2, 1::2]) + np.abs(df_SE[::2, 1::2])
        weight_SW[::2, 1::2] = np.abs(d2[::2, 1::2]) + np.abs(df_SW[::2, 1::2])
        weight_NW[::2, 1::2] = np.abs(d1[::2, 1::2]) + np.abs(df_NW[::2, 1::2])

        # == weights at [1::2, ::2] sites
        weight_NE[1::2, ::2] = np.abs(d2[1::2, ::2]) + np.abs(df_NE[1::2, ::2])
        weight_SE[1::2, ::2] = np.abs(d1[1::2, ::2]) + np.abs(df_SE[1::2, ::2])
        weight_SW[1::2, ::2] = np.abs(d2[1::2, ::2]) + np.abs(df_SW[1::2, ::2])
        weight_NW[1::2, ::2] = np.abs(d1[1::2, ::2]) + np.abs(df_NW[1::2, ::2])

        weight_NE = np.divide(1., 1. + weight_NE)
        weight_SE = np.divide(1., 1. + weight_SE)
        weight_SW = np.divide(1., 1. + weight_SW)
        weight_NW = np.divide(1., 1. + weight_NW)

        # == directional estimates at [::2, 1::2] sites
        # shifted_down[i] = data[i - 1] (row 1 mirrored in at the top)
        shifted_down = np.delete(np.vstack((data[1], data)), -1, 0)
        # plus the column before the last appended on the right
        shifted_down_padded = np.hstack((shifted_down, np.atleast_2d(shifted_down[:, -2]).T))
        value_NE[::2, 1::2] = shifted_down_padded[::2, 2::2] + df_NE[::2, 1::2] / 2.
        padded_right = np.hstack((data, np.atleast_2d(data[:, -2]).T))
        value_SE[::2, 1::2] = padded_right[1::2, 2::2] + df_SE[::2, 1::2] / 2.
        value_SW[::2, 1::2] = data[1::2, ::2] + df_SW[::2, 1::2] / 2.
        # halved like the other three diagonal estimates
        value_NW[::2, 1::2] = shifted_down[::2, ::2] + df_NW[::2, 1::2] / 2.

        # == directional estimates at [1::2, ::2] sites
        value_NE[1::2, ::2] = data[::2, 1::2] + df_NE[1::2, ::2] / 2.
        # append the row before the last at the bottom (parity-preserving)
        padded_right_bottom = np.vstack((padded_right, padded_right[-2]))
        value_SE[1::2, ::2] = padded_right_bottom[2::2, 1::2] + df_SE[1::2, ::2] / 2.
        padded_bottom = np.vstack((data, data[-2]))
        # padded_bottom shifted one column right (column 1 mirrored in at the left)
        padded_bottom_shifted = np.delete(
            np.hstack((np.atleast_2d(padded_bottom[:, 1]).T, padded_bottom)), -1, 1)
        value_SW[1::2, ::2] = padded_bottom_shifted[2::2, ::2] + df_SW[1::2, ::2] / 2.
        shifted_right = np.delete(np.hstack((np.atleast_2d(data[:, 1]).T, data)), -1, 1)
        value_NW[1::2, ::2] = shifted_right[::2, ::2] + df_NW[1::2, ::2] / 2.

        RB = (weight_NE * value_NE + weight_SE * value_SE
              + weight_SW * value_SW + weight_NW * value_NW) \
            / (weight_NE + weight_SE + weight_SW + weight_NW)

        if bayer_pattern == "grbg":
            R[1::2, ::2] = RB[1::2, ::2]
            R[::2, 1::2] = data[::2, 1::2]
            B[::2, 1::2] = RB[::2, 1::2]
            B[1::2, ::2] = data[1::2, ::2]
        else:  # "gbrg"
            R[::2, 1::2] = RB[::2, 1::2]
            R[1::2, ::2] = data[1::2, ::2]
            B[1::2, ::2] = RB[1::2, ::2]
            B[::2, 1::2] = data[::2, 1::2]

        # Fill the green sites; the "rggb" layout keeps [::2, 1::2] and
        # [1::2, ::2] (the R/B sites, now known) as fixed samples.
        R[::2, ::2] = G[::2, ::2]
        R[1::2, 1::2] = G[1::2, 1::2]
        R = fill_channel_directional_weight(R, "rggb")

        B[1::2, 1::2] = G[1::2, 1::2]
        B[::2, ::2] = G[::2, ::2]
        B = fill_channel_directional_weight(B, "rggb")

    return B, R